{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import cv2\n",
    "import torch\n",
    "import numpy as np\n",
    "import PIL\n",
    "from PIL import Image\n",
    "from einops import rearrange\n",
    "from video_vae import CausalVideoVAELossWrapper\n",
    "from torchvision import transforms as pth_transforms\n",
    "from torchvision.transforms.functional import InterpolationMode\n",
    "from IPython.display import Image as ipython_image\n",
    "from diffusers.utils import load_image, export_to_video, export_to_gif\n",
    "from IPython.display import HTML"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = \"pyramid-flow-miniflux/causal_video_vae\"   # The video-vae checkpoint dir\n",
    "model_dtype = 'bf16'\n",
    "\n",
    "device_id = 3\n",
    "torch.cuda.set_device(device_id)\n",
    "\n",
    "model = CausalVideoVAELossWrapper(\n",
    "    model_path,\n",
    "    model_dtype,\n",
    "    interpolate=False, \n",
    "    add_discriminator=False,\n",
    ")\n",
    "model = model.to(\"cuda\")\n",
    "\n",
    "if model_dtype == \"bf16\":\n",
    "    torch_dtype = torch.bfloat16 \n",
    "elif model_dtype == \"fp16\":\n",
    "    torch_dtype = torch.float16\n",
    "else:\n",
    "    torch_dtype = torch.float32\n",
    "\n",
    "def image_transform(images, resize_width, resize_height):\n",
    "    transform_list = pth_transforms.Compose([\n",
    "        pth_transforms.Resize((resize_height, resize_width), InterpolationMode.BICUBIC, antialias=True),\n",
    "        pth_transforms.ToTensor(),\n",
    "        pth_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "    ])\n",
    "    return torch.stack([transform_list(image) for image in images])\n",
    "\n",
    "\n",
    "def get_transform(width, height, new_width=None, new_height=None, resize=False,):\n",
    "    transform_list = []\n",
    "\n",
    "    if resize:\n",
    "        if new_width is None:\n",
    "            new_width = width // 8 * 8\n",
    "        if new_height is None:\n",
    "            new_height = height // 8 * 8\n",
    "        transform_list.append(pth_transforms.Resize((new_height, new_width), InterpolationMode.BICUBIC, antialias=True))\n",
    "    \n",
    "    transform_list.extend([\n",
    "        pth_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
    "    ])\n",
    "    transform_list = pth_transforms.Compose(transform_list)\n",
    "\n",
    "    return transform_list\n",
    "\n",
    "\n",
    "def load_video_and_transform(video_path, frame_number, new_width=None, new_height=None, max_frames=600, sample_fps=24, resize=False):\n",
    "    try:\n",
    "        video_capture = cv2.VideoCapture(video_path)\n",
    "        fps = video_capture.get(cv2.CAP_PROP_FPS)\n",
    "        frames = []\n",
    "        pil_frames = []\n",
    "        while True:\n",
    "            flag, frame = video_capture.read()\n",
    "            if not flag:\n",
    "                break\n",
    "    \n",
    "            pil_frames.append(np.ascontiguousarray(frame[:, :, ::-1]))\n",
    "            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
    "            frame = torch.from_numpy(frame)\n",
    "            frame = frame.permute(2, 0, 1)\n",
    "            frames.append(frame)\n",
    "            if len(frames) >= max_frames:\n",
    "                break\n",
    "\n",
    "        video_capture.release()\n",
    "        interval = max(int(fps / sample_fps), 1)\n",
    "        pil_frames = pil_frames[::interval][:frame_number]\n",
    "        frames = frames[::interval][:frame_number]\n",
    "        frames = torch.stack(frames).float() / 255\n",
    "        width = frames.shape[-1]\n",
    "        height = frames.shape[-2]\n",
    "        video_transform = get_transform(width, height, new_width, new_height, resize=resize)\n",
    "        frames = video_transform(frames)\n",
    "        pil_frames = [Image.fromarray(frame).convert(\"RGB\") for frame in pil_frames]\n",
    "\n",
    "        if resize:\n",
    "            if new_width is None:\n",
    "                new_width = width // 32 * 32\n",
    "            if new_height is None:\n",
    "                new_height = height // 32 * 32\n",
    "            pil_frames = [frame.resize((new_width or width, new_height or height), PIL.Image.BICUBIC) for frame in pil_frames]\n",
    "        return frames, pil_frames\n",
    "    except Exception:\n",
    "        return None\n",
    "\n",
    "\n",
    "def show_video(ori_path, rec_path, width=\"100%\"):\n",
    "    html = ''\n",
    "    if ori_path is not None:\n",
    "        html += f\"\"\"<video controls=\"\" name=\"media\" data-fullscreen-container=\"true\" width=\"{width}\">\n",
    "        <source src=\"{ori_path}\" type=\"video/mp4\">\n",
    "        </video>\n",
    "        \"\"\"\n",
    "    \n",
    "    html += f\"\"\"<video controls=\"\" name=\"media\" data-fullscreen-container=\"true\" width=\"{width}\">\n",
    "    <source src=\"{rec_path}\" type=\"video/mp4\">\n",
    "    </video>\n",
    "    \"\"\"\n",
    "    return HTML(html)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Image Reconstruction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_path = 'image_path'\n",
    "\n",
    "image = Image.open(image_path).convert(\"RGB\")\n",
    "resize_width = image.width // 8 * 8\n",
    "resize_height = image.height // 8 * 8\n",
    "input_image_tensor = image_transform([image], resize_width, resize_height)\n",
    "input_image_tensor = input_image_tensor.permute(1, 0, 2, 3).unsqueeze(0)\n",
    "\n",
    "with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):\n",
    "    latent = model.encode_latent(input_image_tensor.to(\"cuda\"), sample=True)\n",
    "    rec_images = model.decode_latent(latent)\n",
    "\n",
    "display(image)\n",
    "display(rec_images[0])"
   ]
  },
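  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional sanity check (not part of the original demo): print the input and latent shapes from the cell above to see how much the VAE compresses the image spatially. This sketch assumes the latent is a 5-D `[B, C, T, h, w]` tensor; adjust the indexing if `encode_latent` returns a different layout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: compare input and latent shapes.\n",
    "# Assumes `latent` from the previous cell is a [B, C, T, h, w] tensor; adjust if the layout differs.\n",
    "print('input: ', tuple(input_image_tensor.shape))\n",
    "print('latent:', tuple(latent.shape))\n",
    "print('spatial compression:', input_image_tensor.shape[-1] // latent.shape[-1], 'x')"
   ]
  },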
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Video Reconstruction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "video_path = 'video_path'\n",
    "\n",
    "frame_number = 57   # x*8 + 1\n",
    "width = 640\n",
    "height = 384\n",
    "\n",
    "video_frames_tensor, pil_video_frames = load_video_and_transform(video_path, frame_number, new_width=width, new_height=height, resize=True)\n",
    "video_frames_tensor = video_frames_tensor.permute(1, 0, 2, 3).unsqueeze(0)\n",
    "print(video_frames_tensor.shape)\n",
    "\n",
    "with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):\n",
    "    latent = model.encode_latent(video_frames_tensor.to(\"cuda\"), sample=False, window_size=8, temporal_chunk=True)\n",
    "    rec_frames = model.decode_latent(latent.float(), window_size=2, temporal_chunk=True)\n",
    "\n",
    "export_to_video(pil_video_frames, './ori_video.mp4', fps=24)\n",
    "export_to_video(rec_frames, \"./rec_video.mp4\", fps=24)\n",
    "show_video('./ori_video.mp4', \"./rec_video.mp4\", \"60%\")"
   ]
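  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough, optional quality check (not part of the original demo), the next cell computes the average per-frame PSNR between the original and reconstructed frames. It assumes `rec_frames` is a list of PIL images with the same length and resolution as `pil_video_frames`, which is how both are passed to `export_to_video` above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough reconstruction-quality check: average per-frame PSNR.\n",
    "# Assumes `rec_frames` and `pil_video_frames` are equal-length lists of same-size PIL images.\n",
    "psnrs = []\n",
    "for ori, rec in zip(pil_video_frames, rec_frames):\n",
    "    ori = np.asarray(ori).astype(np.float64)\n",
    "    rec = np.asarray(rec).astype(np.float64)\n",
    "    mse = np.mean((ori - rec) ** 2)\n",
    "    psnrs.append(20 * np.log10(255.0) - 10 * np.log10(mse + 1e-12))\n",
    "print(f'Average PSNR over {len(psnrs)} frames: {np.mean(psnrs):.2f} dB')"
   ]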
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
